
Conversation

@Oleg-Goncharov (Collaborator)

Description

This PR adds a new kernel that computes dbias separately for each tensor in a group and outputs a grouped dbias tensor containing per-tensor dbias values.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added the grouped dbias kernel

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps bot (Contributor) commented Feb 11, 2026

Greptile Overview

Greptile Summary

This PR adds support for computing per-tensor dbias values in grouped tensor quantization operations. The key changes include:

  • API Change: Modified all grouped quantization functions to accept NVTEGroupedTensor dbias instead of NVTETensor dbias, enabling separate dbias computation for each tensor in a group
  • New Kernel: Implemented group_reduce_dbias_kernel that reduces partial dbias values per-tensor, outputting a 2D tensor with shape [num_tensors, cols]
  • Grid Computation: Unified grid calculation to elts_total / ELTS_PER_CHUNK for all cases, removing the previous DIVUP-based logic for single tensors
  • Tail Tile Support: The unified grid computation now enforces full 128x128 tile alignment for all tensors (previously single tensors could have tail tiles)

The implementation correctly handles different shape representations (SAME_BOTH_DIMS, VARYING_FIRST_DIM, etc.) and maintains the restriction that dbias is only supported for tensors with constant last dimension.
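
As a rough illustration of the reduction step described above, a per-tensor dbias reduction over partial results could look like the sketch below. This is a minimal sketch, not the actual group_reduce_dbias_kernel: the workspace layout and the partial_row_offsets / partial_row_counts arrays are assumptions made for the example.

```cpp
// Minimal sketch of a per-tensor dbias reduction (assumed workspace layout:
// one partial dbias row of length `cols` per 128-row chunk, with all tensors'
// partial rows laid out back to back). Not the actual group_reduce_dbias_kernel.
__global__ void group_reduce_dbias_sketch(const float *__restrict__ partials,
                                          float *__restrict__ dbias_out,  // [num_tensors, cols]
                                          const size_t *__restrict__ partial_row_offsets,
                                          const size_t *__restrict__ partial_row_counts,
                                          size_t num_tensors, size_t cols) {
  const size_t tensor_id = blockIdx.y;
  const size_t col = blockIdx.x * blockDim.x + threadIdx.x;
  if (tensor_id >= num_tensors || col >= cols) return;

  // Sum only the partial rows belonging to this tensor, so each tensor
  // gets its own reduced dbias vector instead of one shared vector.
  float acc = 0.0f;
  const size_t begin = partial_row_offsets[tensor_id];
  const size_t end = begin + partial_row_counts[tensor_id];
  for (size_t r = begin; r < end; ++r) {
    acc += partials[r * cols + col];
  }
  dbias_out[tensor_id * cols + col] = acc;
}
```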

Issues to verify:

  • The removal of tail tile support for single tensors is a breaking change (test with last_dim=160 was removed)
  • Ensure the workspace offset calculation in the new reduction kernel handles all shape representations correctly
  • The full-tile requirement should be clearly documented for users

Confidence Score: 4/5

  • This PR is generally safe to merge with minor verification needed on edge cases
  • The implementation is well-structured with proper test coverage for the new grouped dbias functionality. However, the removal of tail tile support represents a breaking change that removed a test case, and some edge case calculations in the reduction kernel should be verified for correctness across all shape representations.
  • Pay attention to group_quantize_mxfp8.cuh (grid computation change) and common.cuh (new reduction kernel with offset calculations)

Important Files Changed

Filename | Overview
transformer_engine/common/include/transformer_engine/cast.h | Changed dbias parameter from NVTETensor to NVTEGroupedTensor in all grouped quantization APIs to support per-tensor dbias computation
transformer_engine/common/cast/core/common.cuh | Added ShapeRepresentation enum and group_reduce_dbias_kernel to compute per-tensor dbias values for grouped tensors
transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh | Changed grid computation to always require full 128x128 tiles (removed tail tile support for is_single_tensor), integrated grouped dbias reduction, moved ShapeRepresentation to common
tests/cpp/operator/test_cast_mxfp8_grouped.cu | Updated tests to handle grouped dbias tensors with shape [num_tensors, cols], removed test case with non-128-divisible last dimension (160), refactored reference computation to compute per-tensor dbias
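
For orientation, the per-tensor reference that such a test could check can be written as a column-wise sum over each tensor's rows. The helper below is a hypothetical host-side sketch, not the test's actual reference code, and assumes dbias is defined as that column-wise sum.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical reference: dbias for tensor t is the column-wise sum over all of
// that tensor's rows; results are stacked into shape [num_tensors, cols].
std::vector<float> ref_grouped_dbias(const std::vector<std::vector<float>> &tensors,
                                     size_t cols) {
  std::vector<float> dbias(tensors.size() * cols, 0.0f);
  for (size_t t = 0; t < tensors.size(); ++t) {
    const size_t rows = tensors[t].size() / cols;
    for (size_t r = 0; r < rows; ++r) {
      for (size_t c = 0; c < cols; ++c) {
        dbias[t * cols + c] += tensors[t][r * cols + c];
      }
    }
  }
  return dbias;
}
```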

Sequence Diagram

sequenceDiagram
    participant User
    participant API as nvte_group_quantize_dbias
    participant Dispatch as group_quantize_bwd_helper
    participant Kernel as group_quantize_mxfp8_kernel
    participant Reduction as group_reduce_dbias_kernel
    participant Output as GroupedTensor dbias

    User->>API: Call with grouped input/output tensors
    API->>Dispatch: Forward to dispatch layer
    Dispatch->>Dispatch: Convert NVTEGroupedTensor to GroupedTensor
    Dispatch->>Kernel: Launch quantization kernel
    Note over Kernel: Compute partial dbias per 128x128 tile
    Kernel->>Kernel: Write partial results to workspace
    Kernel-->>Dispatch: Return partial dbias in workspace
    Dispatch->>Reduction: Launch group_reduce_dbias_kernel
    Note over Reduction: Reduce partial dbias per-tensor<br/>across chunk_dim_Y rows
    Reduction->>Output: Write per-tensor dbias [num_tensors, cols]
    Output-->>User: Return grouped dbias tensor

@greptile-apps bot (Contributor) left a comment

9 files reviewed, 1 comment

Comment on lines +787 to 795
  if (!is_single_tensor) {
    NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS,
               "Number of tensors in a group is larger than "
               "the MAX number of supported descriptors (64).");
    // Only full tiles supported
    NVTE_CHECK(last_logical_dim % CHUNK_DIM_X == 0,
               "Last dimension of a grouped tensor should be divisible by 128.");
    blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X);
  }
  const dim3 grid(blocks);

  NVTE_CHECK(elts_total % ELTS_PER_CHUNK == 0, "Only full-tile grouped tensors supported.");
  const dim3 grid(elts_total / ELTS_PER_CHUNK);
  const size_t block_size = THREADS_PER_CHUNK;
Contributor

Regression: tail tiles disallowed

group_quantize now enforces elts_total % ELTS_PER_CHUNK == 0 and sets grid(elts_total / ELTS_PER_CHUNK) for all shapes, including the is_single_tensor case (shape_rep == SAME_BOTH_DIMS || VARYING_FIRST_DIM). In the previous implementation, the is_single_tensor path computed the grid from DIVUP(first_logical_dim, 128) and DIVUP(last_logical_dim, 128), so tail tiles were accepted.

If callers previously passed non-128-multiple M/K for constant-last-dim groups, this change will now hard-fail. If dropping tail support is intentional, it should be documented and covered by tests; otherwise the grid/check logic likely needs to preserve the old DIVUP behavior for is_single_tensor.
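
One way to keep the old behavior would be to branch on is_single_tensor, as sketched below; this is an untested illustration using the names visible in this diff, and the quantization kernel itself would also need to retain its tail-tile handling.

```cpp
size_t blocks;
if (is_single_tensor) {
  // DIVUP rounds up, so a tail tile (e.g. last_dim = 160) still gets a block,
  // matching the pre-PR behavior for constant-last-dim single tensors.
  blocks = DIVUP(first_logical_dim, CHUNK_DIM_Y) * DIVUP(last_logical_dim, CHUNK_DIM_X);
} else {
  NVTE_CHECK(elts_total % ELTS_PER_CHUNK == 0, "Only full-tile grouped tensors supported.");
  blocks = elts_total / ELTS_PER_CHUNK;
}
const dim3 grid(blocks);
```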

@Oleg-Goncharov (Collaborator, Author)

/te-ci

Oleg-Goncharov and others added 2 commits February 12, 2026 14:57
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@greptile-apps bot (Contributor) left a comment

9 files reviewed, 3 comments

Comment on lines +809 to +810
  NVTE_CHECK(elts_total % ELTS_PER_CHUNK == 0, "Only full-tile grouped tensors supported.");
  const dim3 grid(elts_total / ELTS_PER_CHUNK);
Contributor

The new grid computation elts_total / ELTS_PER_CHUNK now requires all tensors (including is_single_tensor cases) to be exact multiples of 128x128 tiles. The old implementation used DIVUP for single tensors, allowing tail tiles. A test case with last_dim=160 was removed. Verify this breaking change is intentional and documented for users.
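
As a concrete illustration (hypothetical row count), a single 256 x 160 tensor gives elts_total = 40960 while ELTS_PER_CHUNK = 128 * 128 = 16384, so elts_total % ELTS_PER_CHUNK = 8192 and the new check rejects the shape; the old DIVUP path would have launched DIVUP(256, 128) * DIVUP(160, 128) = 2 * 2 = 4 chunks and processed the tail columns.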

                           ? (first_logical_dim / num_tensors)
                           : first_dims_ptr[tensor_id];

  const size_t rows = tensor_rows / chunk_dim_Y;
Contributor

Verify that tensor_rows is always divisible by chunk_dim_Y (128), otherwise this division silently truncates and skips tail row reduction.
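
For example (hypothetical values), tensor_rows = 200 with chunk_dim_Y = 128 gives rows = 200 / 128 = 1 under integer division, so the partial dbias row covering rows 128-199 would never be accumulated; whether the upstream full-tile checks guarantee per-tensor divisibility for every shape representation is exactly what needs confirming.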

Comment on lines +108 to +110
const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
? (tensor_id * (tensor_rows / chunk_dim_Y))
: (offsets_ptr[tensor_id] / cols / chunk_dim_Y);
Contributor

Verify that offsets_ptr[tensor_id] / cols / chunk_dim_Y correctly computes the workspace offset for varying first dimensions.
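
A hypothetical worked example: with cols = 256 and chunk_dim_Y = 128, a tensor starting at element offset offsets_ptr[tensor_id] = 98304 maps to 98304 / 256 / 128 = 3, i.e. its partials start at the fourth workspace row; this is only correct if the offsets are element counts for a constant last dimension and every preceding tensor contributes a whole number of 128-row chunks.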

Comment on lines +147 to +150
  if (global_dim_X % CHUNK_DIM_X != 0) {
    NVTE_DEVICE_ERROR(
        "The grouped tensor must be divisible by 128x128 tiles without a tail tile.");
  }
Member

Let's see the performance impact of having this here.
